import os
import json
import math
import loguru
import torch
import transformers

from deepspeed        import zero
from typing           import Dict, Optional, List
from torch.utils.data import Dataset
from transformers     import Trainer, GPTQConfig, deepspeed
from accelerate.utils import DistributedType
from dataclasses      import dataclass, field

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType
from transformers.trainer_pt_utils               import LabelSmoother
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

logger          = loguru.logger
IGNORE_TOKEN_ID = LabelSmoother.ignore_index   # -100
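# Label positions set to -100 are skipped by PyTorch's CrossEntropyLoss (its default ignore_index),
# so masked prompt and padding tokens contribute nothing to the loss.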


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="Qwen/Qwen-7B")


@dataclass
class DataArguments:
    data_path: str = field(
        default=None, metadata={"help": "Path to the training data."}
    )
    eval_data_path: str = field(
        default=None, metadata={"help": "Path to the evaluation data."}
    )
    lazy_preprocess: bool = False


@dataclass
class TrainingArguments(transformers.TrainingArguments):     #! most arguments are inherited from the parent class
    cache_dir: Optional[str] = field(default=None)
    optim: str               = field(default="adamw_torch")
    model_max_length: int    = field(
        default  = 8192,
        metadata = {
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    use_lora: bool = False


# LoRA settings; the module dump below (Qwen-1_8B) shows the Linear layers that lora_target_modules refers to:
# QWenLMHeadModel(
#   (transformer): QWenModel(
#     (wte): Embedding(151936, 2048)
#     (drop): Dropout(p=0.0, inplace=False)
#     (rotary_emb): RotaryEmbedding()
#     (h): ModuleList(
#       (0-23): 24 x QWenBlock(
#         (ln_1): RMSNorm()
#         (attn): QWenAttention(
#           (c_attn): Linear(in_features=2048, out_features=6144, bias=True)
#           (c_proj): Linear(in_features=2048, out_features=2048, bias=False)
#           (core_attention_flash): FlashSelfAttention()
#           (attn_dropout): Dropout(p=0.0, inplace=False)
#         )
#         (ln_2): RMSNorm()
#         (mlp): QWenMLP(
#           (w1): Linear(in_features=2048, out_features=5504, bias=False)
#           (w2): Linear(in_features=2048, out_features=5504, bias=False)
#           (c_proj): Linear(in_features=5504, out_features=2048, bias=False)
#         )
#       )
#     )
#     (ln_f): RMSNorm()
#   )
#   (lm_head): Linear(in_features=2048, out_features=151936, bias=False)
# )
@dataclass
class LoraArguments:
    lora_r: int                    = 64
    lora_alpha: int                = 16
    lora_dropout: float            = 0.05
    lora_target_modules: List[str] = field(
        default_factory=lambda: ["c_attn", "c_proj", "w1", "w2"]
    )
    lora_weight_path: str = ""
    lora_bias: str        = "none"
    q_lora: bool          = False


def maybe_zero_3(param):
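    """Return a CPU copy of `param`; if it is partitioned by ZeRO-3 (i.e. it carries a ds_id),
    gather the full tensor from its shards first."""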
    if hasattr(param, "ds_id"):
        assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param


# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                bias_name = k.split("lora_")[0] + "bias"
                lora_bias_names.add(bias_name)
            elif "bias" in k:
                maybe_lora_bias[k] = t
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError
    to_return = {k: maybe_zero_3(v) for k, v in to_return.items()}
    return to_return


local_rank = None

def rank0_print(*args):
    if local_rank == 0:
        print(*args)


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, 
                                   output_dir: str, bias="none"):
    """Collects the state dict and dump to disk."""
    # check if zero3 mode enabled
    if deepspeed.is_deepspeed_zero3_enabled():
        #! trainer.model_wrapped is the model after it has been wrapped by another framework (e.g. DeepSpeed);
        # without wrapping, trainer.model_wrapped is trainer.model,
        # so _zero3_consolidated_16bit_state_dict here is a method of the DeepSpeed engine.
        state_dict = trainer.model_wrapped._zero3_consolidated_16bit_state_dict()
    else:
        if trainer.args.use_lora:
            state_dict = get_peft_state_maybe_zero_3(
                #! named_parameters() yields (name: str, value: torch.nn.Parameter) pairs,
                # for example:
                # base_model.model.transformer.h.23.mlp.w1.base_layer.weight torch.Size([5504, 2048])
                # base_model.model.transformer.h.23.mlp.w1.lora_A.default.weight torch.Size([64, 2048])
                # base_model.model.transformer.h.23.mlp.w1.lora_B.default.weight torch.Size([5504, 64])
                # base_model.model.transformer.h.23.mlp.w2.base_layer.weight torch.Size([5504, 2048])
                # base_model.model.transformer.h.23.mlp.w2.lora_A.default.weight torch.Size([64, 2048])
                # base_model.model.transformer.h.23.mlp.w2.lora_B.default.weight torch.Size([5504, 64])
                # base_model.model.transformer.h.23.mlp.c_proj.base_layer.weight torch.Size([2048, 5504])
                # base_model.model.transformer.h.23.mlp.c_proj.lora_A.default.weight torch.Size([64, 5504])
                # base_model.model.transformer.h.23.mlp.c_proj.lora_B.default.weight torch.Size([2048, 64])
                # base_model.model.transformer.ln_f.weight torch.Size([2048])
                # base_model.model.lm_head.weight torch.Size([151936, 2048])
                trainer.model.named_parameters(), bias
            )
        else:
            state_dict = trainer.model.state_dict()
    if trainer.args.should_save and trainer.args.local_rank == 0:
        trainer._save(output_dir, state_dict=state_dict)


def preprocess(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    max_len: int,
    system_message: str = "You are a helpful assistant."
) -> Dict:
    roles = {"user": "<|im_start|>user", "assistant": "<|im_start|>assistant"}

    im_start   = tokenizer.im_start_id
    im_end     = tokenizer.im_end_id
    nl_tokens  = tokenizer('\n').input_ids                      # [198]
    _system    = tokenizer('system').input_ids    + nl_tokens   # [8948,  198] 
    # _user      = tokenizer('user').input_ids      + nl_tokens   # [872,   198]
    # _assistant = tokenizer('assistant').input_ids + nl_tokens   # [77091, 198]
    
    # Apply prompt templates
    input_ids, targets = [], []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != roles["user"]:  # if the first turn is not from the user, drop it so every conversation starts with a user turn
            source = source[1:]

        input_id, target = [], []                  #! model input token ids and autoregressive target token ids
        system = [im_start] + _system + tokenizer(system_message).input_ids + [im_end] + nl_tokens  # 1. prepend the system prompt
        input_id += system
        target += [im_start] + [IGNORE_TOKEN_ID] * (len(system)-3) + [im_end] + nl_tokens           # 2. the system prompt is not supervised: mask it with IGNORE_TOKEN_ID (the -3 keeps im_start, im_end and the trailing newline as real targets)

        assert len(input_id) == len(target)
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            _input_id = tokenizer(role).input_ids + nl_tokens + \
                tokenizer(sentence["value"]).input_ids + [im_end] + nl_tokens
            input_id += _input_id
            if role == '<|im_start|>user':                                                          # 3. user turns are not supervised either: mask everything between im_start and im_end
                _target = [im_start] + [IGNORE_TOKEN_ID] * (len(_input_id)-3) + [im_end] + nl_tokens
            elif role == '<|im_start|>assistant':                                                   # 4. assistant turns are supervised: mask only the role header, keep the real token ids of the response
                _target = [im_start] + [IGNORE_TOKEN_ID] * len(tokenizer(role).input_ids) + \
                    _input_id[len(tokenizer(role).input_ids)+1:-2] + [im_end] + nl_tokens
            else:
                raise NotImplementedError
            target += _target
        assert len(input_id) == len(target)
        input_id += [tokenizer.pad_token_id] * (max_len - len(input_id))                            # 5. right-pad the input to max_len with the pad token
        target += [IGNORE_TOKEN_ID] * (max_len - len(target))                                       # 6. pad the labels with IGNORE_TOKEN_ID so padding adds no loss
        input_ids.append(input_id[:max_len])
        targets.append(target[:max_len])
    input_ids = torch.tensor(input_ids, dtype=torch.int)
    targets = torch.tensor(targets, dtype=torch.int)

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )


class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer, max_len: int):
        super(SupervisedDataset, self).__init__()

        rank0_print("Formatting inputs...")
        sources = [example["conversations"] for example in raw_data]
        data_dict = preprocess(sources, tokenizer, max_len)

        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]
        self.attention_mask = data_dict["attention_mask"]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(
            input_ids=self.input_ids[i],
            labels=self.labels[i],
            attention_mask=self.attention_mask[i],
        )


class LazySupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer, max_len: int):
        super(LazySupervisedDataset, self).__init__()
        self.tokenizer = tokenizer
        self.max_len = max_len

        rank0_print("Formatting inputs...Skip in lazy mode")
        # self.tokenizer = tokenizer
        self.raw_data = raw_data       # List[Dict]
        self.cached_data_dict = {}     #! per-index cache: __getitem__ tokenizes and builds tensors, so reuse the result on repeated access

    def __len__(self):
        return len(self.raw_data)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        if i in self.cached_data_dict:
            return self.cached_data_dict[i]

        # self.raw_data[i]["conversations"] = [
        #         {
        #             "from"  : "user",
        #             "value" : data["input"]
        #         },
        #         {
        #             "from"  : "assistant",
        #             "value" : data["output"]
        #         }
        #     ]
        ret = preprocess([self.raw_data[i]["conversations"]], self.tokenizer, self.max_len)
        ret = dict(
            input_ids=ret["input_ids"][0],
            labels=ret["labels"][0],
            attention_mask=ret["attention_mask"][0],
        )
        self.cached_data_dict[i] = ret    

        return ret


def make_supervised_data_module(
    tokenizer: transformers.PreTrainedTokenizer, data_args, max_len,
) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    dataset_cls = (
        #! LazySupervisedDataset tokenizes on demand and caches results; SupervisedDataset preprocesses everything up front in __init__
        LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset
    )
    rank0_print("Loading data...")

    train_json = json.load(open(data_args.data_path, "r"))
    train_dataset = dataset_cls(train_json, tokenizer=tokenizer, max_len=max_len)

    if data_args.eval_data_path:
        eval_json = json.load(open(data_args.eval_data_path, "r"))
        eval_dataset = dataset_cls(eval_json, tokenizer=tokenizer, max_len=max_len)
    else:
        eval_dataset = None

    return dict(train_dataset=train_dataset, eval_dataset=eval_dataset)
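
# Expected shape of the training / eval JSON (inferred from the preprocessing above):
# [
#   {"conversations": [
#       {"from": "user",      "value": "..."},
#       {"from": "assistant", "value": "..."}
#   ]},
#   ...
# ]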


def train():
    global local_rank

    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments, LoraArguments)
    )
    (
        model_args,     # ModelArguments(model_name_or_path='/home/yangxianpku/models/Qwen/Qwen-1_8B-Chat')
        data_args,      # DataArguments(data_path='sft_data.json', eval_data_path=None, lazy_preprocess=True)
        training_args,  # TrainingArguments; most fields are inherited from transformers.TrainingArguments
        lora_args,      # LoRA hyper-parameters
    ) = parser.parse_args_into_dataclasses()

    # Only needed for single-GPU QLoRA runs
    if getattr(training_args, 'deepspeed', None) and int(os.environ.get("WORLD_SIZE", 1))==1:
        training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED

    local_rank = training_args.local_rank

    device_map = None
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if lora_args.q_lora:
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else "auto"
        if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled():
            logger.warning(
                "FSDP or ZeRO3 are incompatible with QLoRA."    #! QLoRA与ZeRO3不兼容
            )

    #! plain LoRA is incompatible with ZeRO-3 when fine-tuning a base (non-chat) model
    is_chat_model = 'chat' in model_args.model_name_or_path.lower()
    if (
            training_args.use_lora
            and not lora_args.q_lora
            and deepspeed.is_deepspeed_zero3_enabled()
            and not is_chat_model
    ):
        raise RuntimeError("ZeRO3 is incompatible with LoRA when finetuning on base model.")

    model_load_kwargs = {
        'low_cpu_mem_usage': not deepspeed.is_deepspeed_zero3_enabled(),
    }

    # Set RoPE scaling factor
    # "rotary_emb_base": 10000,
    # "rotary_pct": 1.0,
    config = transformers.AutoConfig.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,     #! None
        trust_remote_code=True,
    )
    config.use_cache = False

    # Load model and tokenizer
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        cache_dir=training_args.cache_dir,
        device_map=device_map,
        trust_remote_code=True,
        quantization_config=GPTQConfig(
            bits=4, disable_exllama=True
        )
        if training_args.use_lora and lora_args.q_lora
        else None,
        **model_load_kwargs,
    )
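    # With --use_lora --q_lora, the base weights are loaded as a 4-bit GPTQ checkpoint;
    # disable_exllama=True because the ExLlama kernels target inference and do not support fine-tuning.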
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
        trust_remote_code=True,
    )
    tokenizer.pad_token_id = tokenizer.eod_id
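    # Qwen's tokenizer defines no pad token, so the end-of-document token is reused for right padding;
    # padded positions are masked out via attention_mask and IGNORE_TOKEN_ID.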

    if training_args.use_lora:
        if lora_args.q_lora or is_chat_model:
            modules_to_save = None
        else:
            modules_to_save = ["wte", "lm_head"]
        lora_config = LoraConfig(
            r=lora_args.lora_r,
            lora_alpha=lora_args.lora_alpha,
            target_modules=lora_args.lora_target_modules,  # the Linear layers that receive LoRA adapters
            lora_dropout=lora_args.lora_dropout,
            bias=lora_args.lora_bias,                      # "none"
            task_type = TaskType.CAUSAL_LM,
            modules_to_save=modules_to_save  # This argument serves for adding new tokens.
        )

        # For QLoRA: prepare the quantized model for k-bit training (freeze the base weights,
        # cast selected layers to fp32 for stability, and set up gradient checkpointing)
        if lora_args.q_lora:
            model = prepare_model_for_kbit_training(
                model, use_gradient_checkpointing=training_args.gradient_checkpointing
            )

        model = get_peft_model(model, lora_config)

        # Print peft trainable params
        model.print_trainable_parameters()

        # With a frozen base model, gradient checkpointing needs the embedding outputs to require
        # grad so that activations can be recomputed in the backward pass
        if training_args.gradient_checkpointing:
            model.enable_input_require_grads()

    # Load data
    data_module = make_supervised_data_module(
        tokenizer=tokenizer,
        data_args=data_args,
        max_len=training_args.model_max_length,
    )

    # Start trainer
    trainer = Trainer(
        model=model, tokenizer=tokenizer, args=training_args, **data_module
    )

    trainer.train()
    trainer.save_state()

    safe_save_model_for_hf_trainer(
        trainer=trainer,
        output_dir=training_args.output_dir,
        bias=lora_args.lora_bias,
    )


if __name__ == "__main__":
    train()
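
# Example launch (illustrative only; script name, DeepSpeed config and hyper-parameter values are placeholders):
#   torchrun --nproc_per_node=1 finetune.py \
#     --model_name_or_path Qwen/Qwen-1_8B-Chat \
#     --data_path sft_data.json \
#     --output_dir output_qwen \
#     --use_lora True \
#     --lazy_preprocess True \
#     --model_max_length 2048 \
#     --bf16 True \
#     --per_device_train_batch_size 2 \
#     --gradient_accumulation_steps 8 \
#     --learning_rate 3e-4 \
#     --num_train_epochs 3 \
#     --gradient_checkpointing \
#     --deepspeed ds_config_zero2.json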

